{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "f2a6e764",
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import AutoModelForMaskedLM, AutoTokenizer, AutoConfig"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "a7db2a4f",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "7ac382344a734bdfa102d544631bca54",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Downloading (…)lve/main/config.json:   0%|          | 0.00/578 [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "config = AutoConfig.from_pretrained('microsoft/deberta-v3-small')\n",
    "tokenizer = AutoTokenizer.from_pretrained('malaysia-ai/bpe-tokenizer')\n",
    "special_tokens_dict = {\"mask_token\": \"[MASK]\"}\n",
    "num_added_toks = tokenizer.add_special_tokens(special_tokens_dict)\n",
    "config.vocab_size = len(tokenizer)\n",
    "config.max_position_embeddings = 4096"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "c67c07df",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "DebertaV2Config {\n",
       "  \"_name_or_path\": \"microsoft/deberta-v3-small\",\n",
       "  \"attention_probs_dropout_prob\": 0.1,\n",
       "  \"hidden_act\": \"gelu\",\n",
       "  \"hidden_dropout_prob\": 0.1,\n",
       "  \"hidden_size\": 768,\n",
       "  \"initializer_range\": 0.02,\n",
       "  \"intermediate_size\": 3072,\n",
       "  \"layer_norm_eps\": 1e-07,\n",
       "  \"max_position_embeddings\": 4096,\n",
       "  \"max_relative_positions\": -1,\n",
       "  \"model_type\": \"deberta-v2\",\n",
       "  \"norm_rel_ebd\": \"layer_norm\",\n",
       "  \"num_attention_heads\": 12,\n",
       "  \"num_hidden_layers\": 6,\n",
       "  \"pad_token_id\": 0,\n",
       "  \"pooler_dropout\": 0,\n",
       "  \"pooler_hidden_act\": \"gelu\",\n",
       "  \"pooler_hidden_size\": 768,\n",
       "  \"pos_att_type\": [\n",
       "    \"p2c\",\n",
       "    \"c2p\"\n",
       "  ],\n",
       "  \"position_biased_input\": false,\n",
       "  \"position_buckets\": 256,\n",
       "  \"relative_attention\": true,\n",
       "  \"share_att_key\": true,\n",
       "  \"transformers_version\": \"4.35.0\",\n",
       "  \"type_vocab_size\": 0,\n",
       "  \"vocab_size\": 32001\n",
       "}"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "config"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "e89d330e",
   "metadata": {},
   "outputs": [],
   "source": [
    "model = AutoModelForMaskedLM.from_config(config)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "28480ee6",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "('debertav2-small/tokenizer_config.json',\n",
       " 'debertav2-small/special_tokens_map.json',\n",
       " 'debertav2-small/tokenizer.json')"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.save_pretrained('debertav2-small')\n",
    "tokenizer.save_pretrained('debertav2-small')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "69e4d5b8",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "total 262M\r\n",
      "-rw-r--r-- 1 ubuntu ubuntu  866 Nov 15 06:35 config.json\r\n",
      "-rw-r--r-- 1 ubuntu ubuntu 260M Nov 15 06:35 model.safetensors\r\n",
      "-rw-r--r-- 1 ubuntu ubuntu  778 Nov 15 06:35 special_tokens_map.json\r\n",
      "-rw-r--r-- 1 ubuntu ubuntu 1.3M Nov 15 06:35 tokenizer.json\r\n",
      "-rw-r--r-- 1 ubuntu ubuntu 1.2K Nov 15 06:35 tokenizer_config.json\r\n"
     ]
    }
   ],
   "source": [
    "!ls -lh debertav2-small"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "7ded1819",
   "metadata": {},
   "outputs": [],
   "source": [
    "from streaming.base.format.mds.encodings import Encoding, _encodings\n",
    "from streaming import StreamingDataset\n",
    "import torch\n",
    "import numpy as np\n",
    "\n",
    "class UInt16(Encoding):\n",
    "    def encode(self, obj) -> bytes:\n",
    "        return obj.tobytes()\n",
    "\n",
    "    def decode(self, data: bytes):\n",
    "        return np.frombuffer(data, np.uint16)\n",
    "\n",
    "_encodings['uint16'] = UInt16\n",
    "\n",
    "class DatasetFixed(torch.utils.data.Dataset):\n",
    "    def __init__(self, local):\n",
    "        self.dataset = StreamingDataset(local=local)\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        data = self.dataset[idx]\n",
    "        data.pop('token_type_ids', None)\n",
    "        for k in data.keys():\n",
    "            data[k] = data[k].astype(np.int64)\n",
    "        return data\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.dataset)\n",
    "\n",
    "train_dataset = DatasetFixed(local='tokenized-512')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "461ad7c4",
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import DataCollatorForLanguageModeling\n",
    "\n",
    "data_collator = DataCollatorForLanguageModeling(\n",
    "    tokenizer=tokenizer,\n",
    "    mlm_probability=0.15,\n",
    "    pad_to_multiple_of=None,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "d04f1ad4",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "You're using a PreTrainedTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.\n"
     ]
    }
   ],
   "source": [
    "batch = [train_dataset[i] for i in range(3)]\n",
    "b = data_collator(batch)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "38540f61",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "MaskedLMOutput(loss=tensor(10.5336, grad_fn=<NllLossBackward0>), logits=tensor([[[ 0.0000,  0.0178,  0.7783,  ...,  0.0152, -0.3278,  0.8059],\n",
       "         [ 0.0000,  0.0064, -1.1060,  ...,  0.5021, -0.4350, -0.1890],\n",
       "         [ 0.0000,  0.7313,  0.6659,  ...,  0.3666, -0.5069,  0.3751],\n",
       "         ...,\n",
       "         [ 0.0000,  0.2071,  0.5059,  ...,  0.4080, -0.3686,  0.3104],\n",
       "         [ 0.0000,  0.6307,  0.4583,  ...,  0.5337, -0.1931,  0.9158],\n",
       "         [ 0.0000,  0.1186,  0.3102,  ...,  1.1024, -0.1440,  0.7481]],\n",
       "\n",
       "        [[ 0.0000,  0.3048,  0.8347,  ..., -1.0490,  0.1279,  0.3387],\n",
       "         [ 0.0000, -0.1180,  0.0993,  ...,  0.2921,  0.7044, -0.0275],\n",
       "         [ 0.0000,  0.0527,  0.9007,  ...,  0.3695,  0.3636,  0.5029],\n",
       "         ...,\n",
       "         [ 0.0000,  1.0467, -0.2307,  ...,  0.2358, -0.0311,  0.1480],\n",
       "         [ 0.0000,  0.7975, -0.3127,  ...,  0.0608, -0.4451, -0.5585],\n",
       "         [ 0.0000, -0.1052, -0.2830,  ...,  0.6167, -0.4749,  0.4427]],\n",
       "\n",
       "        [[ 0.0000,  0.0829,  0.6659,  ...,  0.3530, -0.5330,  0.5139],\n",
       "         [ 0.0000,  1.0076,  0.4548,  ...,  0.1174,  0.3041,  0.1728],\n",
       "         [ 0.0000, -1.0286,  0.9499,  ...,  1.2319,  0.2825, -0.0674],\n",
       "         ...,\n",
       "         [ 0.0000, -0.8834,  1.0734,  ...,  0.9293, -0.5939, -0.1793],\n",
       "         [ 0.0000,  0.9625,  0.1057,  ...,  0.1323, -0.1824,  0.7873],\n",
       "         [ 0.0000,  0.1886, -0.5654,  ..., -0.2442, -0.0658,  0.2067]]],\n",
       "       grad_fn=<ViewBackward0>), hidden_states=None, attentions=None)"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model(**b)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8a7061de",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
